from collections import defaultdict
import jax
import jax.numpy as jnp
import json

from functools import partial
import numpy as np

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

@jax.jit
def process_batch(params, batch, k):
    """
    Encode a batch, keep the largest-magnitude latents, and decode it again.

    Args:
        params: Nested mapping read as ``params["tied_bias"]``,
            ``params["encoder"]["weights"|"bias"]`` and
            ``params["decoder"]["weights"|"bias"]``
        batch: Input batch
        k: Number of active units to keep per example

    Returns:
        Tuple of (sparse_codes, reconstructions)
    """
    tied_bias = params["tied_bias"]
    enc = params["encoder"]
    dec = params["decoder"]

    # Encoder: centre the inputs with the tied bias, then apply the affine map.
    latents = jnp.dot(batch - tied_bias, enc["weights"]) + enc["bias"]

    # The k-th largest magnitude along the last axis is the keep threshold.
    magnitudes = jnp.abs(latents)
    descending = -jnp.sort(-magnitudes, axis=-1)
    threshold = descending[..., k - 1][..., jnp.newaxis]

    # ">=" keeps every latent tied with the threshold, so more than k units
    # can stay active when magnitudes are equal.
    keep = magnitudes >= threshold
    sparse_codes = jnp.where(keep, latents, 0)

    # Decoder: affine map, then add the tied bias back.
    reconstructions = (
        jnp.dot(sparse_codes, dec["weights"]) + dec["bias"] + tied_bias
    )

    return sparse_codes, reconstructions

def get_sparse_representations_and_reconstructions(model_params, inputs, k, batch_size=1024):
    """
    Run ``process_batch`` over ``inputs`` in consecutive slices and stack the results.

    Args:
        model_params: Model parameters forwarded to ``process_batch``
        inputs: Input data, sliced along its first axis
        k: Number of active units, forwarded to ``process_batch``
        batch_size: Number of rows handed to ``process_batch`` per call

    Returns:
        Tuple of (sparse_codes, reconstructions) as concatenated NumPy arrays
    """
    total = inputs.shape[0]
    code_chunks = []
    reconstruction_chunks = []

    for start in range(0, total, batch_size):
        stop = min(start + batch_size, total)
        chunk_codes, chunk_reconstructions = process_batch(
            model_params, inputs[start:stop], k=k
        )

        # Copy the device arrays to host memory so results accumulate as NumPy.
        code_chunks.append(np.array(chunk_codes))
        reconstruction_chunks.append(np.array(chunk_reconstructions))

        # Report progress once every ten batches.
        if (start + batch_size) % (10 * batch_size) == 0:
            print(f"Processed {stop}/{total} samples")

    return np.concatenate(code_chunks), np.concatenate(reconstruction_chunks)

def find_top_k_embeddings_cosine_similarity(embeddings, query_vector, k=20):
    """
    Find the rows of ``embeddings`` with the highest cosine similarity to a
    query vector, using fully vectorized NumPy operations.

    Parameters:
        embeddings: 2D array where each row is an embedding.
        query_vector: 1D array (the query vector).
        k: number of top embeddings to return.

    Returns:
        1D integer array of row indices into ``embeddings``, ordered from the
        highest to the lowest cosine similarity. It holds
        ``min(k, number of embeddings)`` entries and is empty when ``k`` is 0.

    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    # Ensure inputs are float32 NumPy arrays
    query_vector = np.asarray(query_vector, dtype=np.float32)
    embeddings = np.asarray(embeddings, dtype=np.float32)

    # Normalize the query; the epsilon guards against a zero-norm vector
    normalized_query = query_vector / (np.linalg.norm(query_vector) + 1e-8)

    # Normalize every row at once via broadcasting
    embeddings_norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    normalized_embeddings = embeddings / (embeddings_norms + 1e-8)

    # Dot products of unit vectors are cosine similarities
    cosine_similarities = np.dot(normalized_embeddings, normalized_query)

    # k == 0 must be handled explicitly: the slice [-0:] below would select
    # every element instead of none.
    if k == 0:
        return np.empty(0, dtype=np.intp)

    if k >= len(cosine_similarities):
        return np.argsort(cosine_similarities)[::-1]

    # np.argpartition is cheaper than a full sort when only the top k matter
    top_k_indices = np.argpartition(cosine_similarities, -k)[-k:]
    # Sort just the top k in descending order of similarity
    order = np.argsort(cosine_similarities[top_k_indices])[::-1]
    return top_k_indices[order]

def find_associated_words(z, word_code, top_k=10):
    """
    Print, for every active topic of a sparse code, the words that load most
    strongly on that topic.

    Topics are visited in descending order of the absolute value of their
    coefficient in ``word_code``. For each one, the rows of ``z`` with the
    largest absolute coefficient in that column are printed with their
    signed coefficients.

    NOTE(review): words are looked up in a module-level ``vocab_list`` that is
    not defined in the visible part of this file; it presumably has one entry
    per row of ``z`` -- confirm against the caller.

    Args:
        z: 2D array of sparse codes, indexed as ``z[:, topic]``
        word_code: 1D sparse code whose non-zero entries are the active topics
        top_k: Maximum number of words to print per topic

    Returns:
        None. All output goes to stdout.
    """
    # Identify non-zero entries in the sparse code which correspond to active topics
    active_topics = jnp.nonzero(word_code)[0]
    if active_topics.size == 0:
        # The message cannot name the word: this function only receives its code.
        print("No active topics associated with the given word code")
        return

    # Sort active topics by their absolute coefficient values, largest first
    topic_weights = jnp.abs(word_code[active_topics])
    sorted_topics = active_topics[jnp.argsort(topic_weights)[::-1]]

    for topic in sorted_topics:
        topic_weight = word_code[topic]
        print(f"Topic: {topic}")
        print(f"Weight: {topic_weight:.5f}")

        # Coefficients for this topic across all words
        topic_coefficients = z[:, topic]

        # jax.lax.top_k rejects a k larger than the axis length, so clamp it
        num_top = min(top_k, topic_coefficients.shape[0])
        top_word_indices = jax.lax.top_k(jnp.abs(topic_coefficients), num_top)[1]

        print(" Top words for this topic:")
        for idx in top_word_indices:
            vocab_word = vocab_list[idx].strip()
            print(f"  {vocab_word}: {topic_coefficients[idx]:.5f}")
        print("-----------------")


# class SparseCodeMatcher:
#     def __init__(self, z=None):
#         """
#         Initialize the sparse code matcher.
        
#         Args:
#             z: Optional tensor of shape [num_codes, code_dim] containing sparse codes
#         """
#         self.index = defaultdict(list)
#         if z is not None:
#             self.build_index(z)
    
#     def build_index(self, z):
#         """
#         Build an index from sparse codes for efficient retrieval.
        
#         Args:
#             z: Numpy array or tensor of shape [num_codes, code_dim] containing sparse codes
#         """
#         self.index = defaultdict(list)
        
#         # Handle potential numpy array
#         if isinstance(z, np.ndarray):
#             # For each sparse code
#             for i in range(z.shape[0]):
#                 # Find non-zero entries (numpy version)
#                 nonzero_indices = np.nonzero(z[i])[0]
                
#                 # Add this code's index to each entry's list
#                 for pos in nonzero_indices:
#                     self.index[int(pos)].append(i)

#         else:
#             # Assume it's a torch tensor
#             for i in range(z.shape[0]):
#                 # Find non-zero entries (torch version)
#                 nonzero_indices = torch.nonzero(z[i], as_tuple=True)[0]
                
#                 # Add this code's index to each entry's list
#                 for pos in nonzero_indices:
#                     self.index[pos.item()].append(i)
        
#         return self
    
#     def retrieve_similar_codes(self, query_vector, max_codes=4000):
#         """
#         Retrieve codes similar to the query vector.
        
#         Args:
#             query_vector: Sparse query vector
#             max_codes: Maximum number of codes to return
            
#         Returns:
#             List of code indices
#         """
#         # Get non-zero entries and their values
#         nonzero_indices = torch.nonzero(query_vector, as_tuple=True)[0]
#         values = query_vector[nonzero_indices]
        
#         # Sort entries by value in descending order
#         sorted_indices = torch.argsort(values, descending=True)
#         sorted_positions = nonzero_indices[sorted_indices]
        
#         # Retrieve codes for each position in order
#         retrieved_codes = []
        
#         for i, pos in enumerate(sorted_positions):
#             pos_item = pos.item()
#             if pos_item in self.index:
#                 # Add codes for this position
#                 codes_to_add = self.index[pos_item]
#                 retrieved_codes.extend(codes_to_add)
                
#                 # Check if we've reached the limit
#                 if len(retrieved_codes) >= max_codes:
#                     # retrieved_codes = retrieved_codes[:max_codes]  # Truncate to max_codes
#                     break
                    
#         return list(set(retrieved_codes))

import torch
from collections import defaultdict

class SparseCodeMatcher:
    """
    Inverted index over sparse codes.

    ``self.index`` maps a code dimension (int) to a 1D ``torch.long`` tensor
    holding the row numbers of every sparse code that is non-zero in that
    dimension, in ascending row order.
    """

    def __init__(self, z=None):
        """
        Initialize the sparse code matcher.

        Args:
            z: Optional tensor of shape [num_codes, code_dim] containing sparse codes
        """
        self.index = defaultdict(self._empty_postings)
        if z is not None:
            self.build_index(z)

    @staticmethod
    def _empty_postings():
        """Return the empty posting tensor used for unseen dimensions."""
        return torch.tensor([], dtype=torch.long)

    def build_index(self, z):
        """
        Build an index from sparse codes for efficient retrieval.

        Any previously built index is replaced.

        Args:
            z: Tensor of shape [num_codes, code_dim] containing sparse codes.
                Non-tensor input is converted to a float32 tensor.

        Returns:
            self, so the call can be chained.
        """
        # Ensure z is a torch tensor
        if not isinstance(z, torch.Tensor):
            z = torch.tensor(z, dtype=torch.float32)

        # Accumulate row numbers in plain lists and build each tensor once;
        # growing tensors with torch.cat per entry is quadratic.
        postings = defaultdict(list)
        for row in range(z.shape[0]):
            for pos in torch.nonzero(z[row], as_tuple=True)[0].tolist():
                postings[pos].append(row)

        index = defaultdict(self._empty_postings)
        for pos, rows in postings.items():
            index[pos] = torch.tensor(rows, dtype=torch.long)
        self.index = index

        return self

    def retrieve_similar_codes(
        self, query_vector: torch.Tensor, max_codes: int = 4000
    ) -> torch.Tensor:
        """
        Retrieve the indexed codes that share active dimensions with a query.

        The query's non-zero dimensions are visited from the largest to the
        smallest value. The codes indexed under each one are collected until
        at least ``max_codes`` entries (duplicates included) have been
        gathered.

        Args:
            query_vector: Sparse query vector; non-tensor input is converted
                to a float32 tensor.
            max_codes: Maximum number of codes to return

        Returns:
            1D ``torch.long`` tensor of unique code indices in ascending
            order, holding at most ``max_codes`` entries.
        """
        if not isinstance(query_vector, torch.Tensor):
            query_vector = torch.tensor(query_vector, dtype=torch.float32)

        # Non-zero dimensions of the query, ordered by descending value
        nonzero_indices = torch.nonzero(query_vector, as_tuple=True)[0]
        values = query_vector[nonzero_indices]
        sorted_positions = nonzero_indices[torch.argsort(values, descending=True)]

        # Gather posting tensors and concatenate once at the end
        chunks = []
        gathered = 0
        for pos in sorted_positions.tolist():
            codes_to_add = self.index.get(pos)
            if codes_to_add is None:
                continue
            chunks.append(codes_to_add)
            gathered += codes_to_add.size(0)
            # Check if we've reached the limit
            if gathered >= max_codes:
                break

        if chunks:
            retrieved_codes = torch.cat(chunks)
        else:
            retrieved_codes = torch.tensor([], dtype=torch.long)

        # torch.unique also sorts, so duplicates vanish and order is ascending
        unique_codes = torch.unique(retrieved_codes)

        # NOTE(review): the truncation acts on the sorted unique indices, so
        # it keeps the smallest code indices rather than the codes found
        # under the highest-valued query dimensions -- confirm this is intended.
        if unique_codes.size(0) > max_codes:
            unique_codes = unique_codes[:max_codes]

        return unique_codes

    # def retrieve_similar_codes(self, query_vector, max_codes=4000):
    #     """
    #     Retrieve codes similar to the query vector.
        
    #     Args:
    #         query_vector: Sparse query vector
    #         max_codes: Maximum number of codes to return
            
    #     Returns:
    #         List of code indices
    #     """
    #     # Get non-zero entries and their values
    #     nonzero_indices = torch.nonzero(query_vector, as_tuple=True)[0]
    #     values = query_vector[nonzero_indices]
        
    #     # Sort entries by value in descending order
    #     sorted_indices = torch.argsort(values, descending=True)
    #     sorted_positions = nonzero_indices[sorted_indices]
        
    #     # Retrieve codes for each position in order
    #     retrieved_codes = []
        
    #     for i, pos in enumerate(sorted_positions):
    #         pos_item = pos.item()
    #         if pos_item in self.index:
    #             # Add codes for this position
    #             codes_to_add = self.index[pos_item]
    #             retrieved_codes.extend(codes_to_add)
                
    #             # Check if we've reached the limit
    #             if len(retrieved_codes) >= max_codes:
    #                 # retrieved_codes = retrieved_codes[:max_codes]  # Truncate to max_codes
    #                 break
                    
    #     return list(set(retrieved_codes))
import torch
import torch.nn.functional as F

import re
def quarter_sentence(sentence):
    """
    Return the leading portion of a sentence after punctuation clean-up.

    The sentence is passed through ``fix_punctuation`` and split on
    whitespace, so punctuation stays attached to its word. Despite the
    function name, the cut is made at ``len(words) // 2``: the result is the
    first half of the words (rounded down), joined by single spaces.
    """
    tokens = re.findall(r'\S+', fix_punctuation(sentence))
    cutoff = len(tokens) // 2
    return ' '.join(tokens[:cutoff])

def fix_punctuation(text: str) -> str:
    """
    Normalize the spacing around punctuation in ``text``.

    Three regex substitutions are applied in order:
      0) drop whitespace before an apostrophe (e.g. megan 's -> megan's)
      1) drop whitespace before any of , . ; : ! ?
      2) insert one space after any of , . ; : ! ? that is directly followed
         by a non-whitespace character. This also applies inside runs of
         punctuation and inside numbers.
    """
    substitutions = (
        (r"\s+'", r"'"),
        (r'\s+([,.;:!?])', r'\1'),
        (r'([,.;:!?])(?=\S)', r'\1 '),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text

def compute_kl_divergence(logits1, logits2):
    """
    Compute KL(P||Q) for the distributions obtained by applying softmax over
    the last dimension of two logit tensors.

    Entries where P is zero (e.g. -inf logits) contribute nothing to the sum.

    Args:
        logits1: Logits defining P (unnormalized log probabilities)
        logits2: Logits defining Q (unnormalized log probabilities)

    Returns:
        Scalar tensor: the KL divergence summed over the last dimension and
        averaged over all remaining dimensions. Returns inf if every entry of
        logits1 is -inf.
    """
    # No probability mass anywhere in P: report an infinite divergence.
    if torch.all(torch.isneginf(logits1)):
        return torch.tensor(float('inf'), device=logits1.device)

    # Softmax maps -inf logits to exactly zero probability.
    p = F.softmax(logits1, dim=-1)
    log_p = F.log_softmax(logits1, dim=-1)
    log_q = F.log_softmax(logits2, dim=-1)

    # By convention 0 * log(0 / q) = 0, so only the support of P is evaluated.
    # Selecting the support before the arithmetic avoids -inf - (-inf) terms.
    support = p > 0
    contributions = p[support] * (log_p[support] - log_q[support])

    kl_terms = torch.zeros_like(p)
    kl_terms[support] = contributions

    # Sum across the distribution dimension, then average what remains.
    return kl_terms.sum(dim=-1).mean()

def compute_top_k_overlap(logits1, logits2, k):
    """
    Compute the number of overlapping elements in the top-K predictions of two logit tensors.

    The top-K is taken over the last dimension. Any number of leading
    (batch) dimensions is supported.

    Args:
        logits1: First tensor of logits
        logits2: Second tensor of logits
        k: Number of top elements to consider

    Returns:
        For 1D inputs, the overlap count as an int. For inputs with leading
        dimensions, the overlap count averaged over every leading position,
        as a float.

    Raises:
        ValueError: if the two top-K index tensors do not have the same shape.
    """
    # Get indices of top-k elements for both tensors
    _, top_k_indices1 = torch.topk(logits1, k, dim=-1)
    _, top_k_indices2 = torch.topk(logits2, k, dim=-1)

    if top_k_indices1.shape != top_k_indices2.shape:
        raise ValueError(
            "logits1 and logits2 must agree on all leading dimensions, got "
            f"{tuple(logits1.shape)} and {tuple(logits2.shape)}"
        )

    # Compare every index of the first top-K against every index of the
    # second. Indices are unique within one top-K, so the number of entries
    # of the first that occur in the second is the size of the intersection.
    matches = top_k_indices1.unsqueeze(-1) == top_k_indices2.unsqueeze(-2)
    overlap_per_row = matches.any(dim=-1).sum(dim=-1)

    if overlap_per_row.dim() == 0:
        # Single input (no batch dimension): plain count
        return int(overlap_per_row.item())

    # Average overlap across all leading positions
    return overlap_per_row.sum().item() / overlap_per_row.numel()

def compute_total_variation_distance(logits1, logits2):
    """
    Compute the total variation distance, 0.5 * sum(|p(x) - q(x)|), between
    the softmax distributions of two logit tensors.

    Args:
        logits1: First tensor of logits (unnormalized log probabilities)
        logits2: Second tensor of logits (unnormalized log probabilities)

    Returns:
        Scalar tensor: the distance computed over the last dimension and
        averaged over all remaining dimensions. Returns 1.0 (the maximum
        possible distance) if every entry of either input is -inf.
    """
    # An input made up entirely of -inf defines no distribution.
    for logits in (logits1, logits2):
        if torch.all(torch.isneginf(logits)):
            return torch.tensor(1.0, device=logits1.device)

    # Softmax maps -inf logits to exactly zero probability.
    p = F.softmax(logits1, dim=-1)
    q = F.softmax(logits2, dim=-1)

    # Half of the mean L1 distance between the two distributions.
    mean_l1 = (p - q).abs().sum(dim=-1).mean()
    return 0.5 * mean_l1
